{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0816 16:04:58.048052 4706751936 deprecation_wrapper.py:118] From /Users/lewtun/git/deep-math/notebooks/baselines/lstm.py:4: The name tf.keras.layers.CuDNNLSTM is deprecated. Please use tf.compat.v1.keras.layers.CuDNNLSTM instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0-beta1\n",
      "GPU Available:  False\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import glob\n",
    "import pickle\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import sys\n",
    "sys.path.append('../../')\n",
    "from src.models.attention import LSTMWithAttention\n",
    "from lstm import LSTM_Simple\n",
    "from metrics import exact_match_metric, exact_match_metric_index\n",
    "from callbacks import NValidationSetsCallback, GradientLogger\n",
    "from src.models.generator import DataGenerator, DataGeneratorAttention\n",
    "from src.models.utils import get_sequence_data\n",
    "from tqdm import tqdm\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "print(tf.__version__)\n",
    "print(\"GPU Available: \", tf.test.is_gpu_available())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "ROOT = Path('../../')\n",
    "MODELS = ROOT/'models/'\n",
    "DATA = ROOT/'data/'\n",
    "SETTINGS = ROOT/'settings/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the local settings JSON into a plain dict.\n",
    "with (SETTINGS / 'settings_local.json').open(mode='r') as file:\n",
    "    settings_dict = json.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'math_module': 'arithmetic__add_sub',\n",
       " 'train_level': '*',\n",
       " 'batch_size': 1024,\n",
       " 'thinking_steps': 16,\n",
       " 'epochs': 1,\n",
       " 'num_encoder_units': 512,\n",
       " 'num_decoder_units': 2048,\n",
       " 'embedding_dim': 256,\n",
       " 'save_path': '/artifacts/',\n",
       " 'data_path': 'data/'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "settings_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The '*' is a literal character of the file name; open() does not expand globs.\n",
    "with (DATA / 'processed' / 'arithmetic__add_sub-*.pkl').open('rb') as file:\n",
    "    sequence_data = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['input_token_index', 'target_token_index', 'input_texts', 'target_texts', 'max_encoder_seq_length', 'max_decoder_seq_length', 'num_encoder_tokens', 'num_decoder_tokens', 'num_thinking_steps'])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sequence_data.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define evaluation class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMWithAttentionEvaluator:\n",
    "    \n",
    "    def __init__(self, path):\n",
    "        \n",
    "        with open(str(path/'settings/settings_local.json'), 'r') as file:\n",
    "            self.settings_dict = json.load(file)\n",
    "\n",
    "        with open(str(path/'data/processed/arithmetic__add_sub-*.pkl'), 'rb') as file:\n",
    "            self.sequence_data = pickle.load(file) \n",
    "            self.token_index = self.sequence_data['input_token_index']\n",
    "\n",
    "        self.num_tokens = len(self.token_index)\n",
    "        \n",
    "        adam = Adam(lr=6e-4, beta_1=0.9, beta_2=0.995, epsilon=1e-9, decay=0.0, amsgrad=False, clipnorm=0.1)\n",
    "        self.lstm = LSTMWithAttention(sequence_data['num_encoder_tokens'], \n",
    "                         sequence_data['num_decoder_tokens'], \n",
    "                         sequence_data['max_encoder_seq_length'],\n",
    "                        sequence_data['max_decoder_seq_length'],\n",
    "                        settings_dict['num_encoder_units'],\n",
    "                        settings_dict['num_decoder_units'],\n",
    "                        settings_dict['embedding_dim'])\n",
    "        self.model = self.lstm.get_model()\n",
    "        self.model.load_weights(str(MODELS/'model.h5'))\n",
    "        self.model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=[exact_match_metric_index])\n",
    "        \n",
    "    def evaluate_model(self, input_texts, output_texts, teacher_forcing=True, batch_size=128, n_samples=1000):\n",
    "        max_seq_length  = max([len(txt_in)+len(txt_out) for txt_in, txt_out in zip(input_texts,output_texts)])\n",
    "        \n",
    "        data_gen_pars = {\n",
    "            \"batch_size\": settings_dict[\"batch_size\"],\n",
    "            \"max_encoder_seq_length\": sequence_data['max_encoder_seq_length'],\n",
    "            \"max_decoder_seq_length\": sequence_data['max_decoder_seq_length'],\n",
    "            \"num_encoder_tokens\": sequence_data['num_encoder_tokens'],\n",
    "            \"num_decoder_tokens\": sequence_data['num_decoder_tokens'],\n",
    "            \"input_token_index\": sequence_data['input_token_index'],\n",
    "            \"target_token_index\": sequence_data['target_token_index'],\n",
    "            \"num_thinking_steps\": settings_dict[\"thinking_steps\"],\n",
    "        }\n",
    "        \n",
    "        self.data_generator = DataGeneratorAttention(input_texts=input_texts,\n",
    "                                               target_texts=output_texts,\n",
    "                                               **data_gen_pars)\n",
    "        \n",
    "        if not teacher_forcing:\n",
    "            outputs_true, outputs_preds = self.predict_without_teacher(n_samples, max_seq_length)\n",
    "            exact_match = len([0 for out_true, out_preds in zip(outputs_true, outputs_preds) if out_true.strip()==out_preds.strip()])/len(outputs_true)\n",
    "        \n",
    "        else:\n",
    "            result = self.model.evaluate_generator(self.data_generator, verbose=1)\n",
    "            exact_match = result[1]\n",
    "            \n",
    "        return exact_match\n",
    "    \n",
    "    def predict_on_string(self, text, max_output_length=100):\n",
    "        \n",
    "        max_seq_length = len(text) + max_output_length\n",
    "\n",
    "        \n",
    "        params = {'batch_size': 1,\n",
    "                  'max_seq_length': max_seq_length,\n",
    "                  'num_tokens': self.num_tokens,\n",
    "                  'token_index': self.token_index,\n",
    "                  'num_thinking_steps': self.settings_dict[\"thinking_steps\"]\n",
    "                 }\n",
    "        \n",
    "        \n",
    "        self.data_generator = DataGeneratorSeq(input_texts=[text],\n",
    "                                               target_texts=['0'*max_output_length],\n",
    "                                               **params)\n",
    "        \n",
    "        outputs_true, outputs_preds = self.predict_without_teacher(1, max_seq_length)\n",
    "        \n",
    "        return outputs_preds[0].strip()\n",
    "\n",
    "    def predict_without_teacher(self, n_samples, max_seq_length, random=True):\n",
    "        \n",
    "        encoded_texts = [] \n",
    "        outputs_true = []\n",
    "        if random:\n",
    "            samples = np.random.choice(self.data_generator.indexes, n_samples, replace=False)\n",
    "        else:\n",
    "            samples = list(range(n_samples))\n",
    "        for i in samples:\n",
    "            input_len = len(input_texts_train[i])\n",
    "            sample = self.data_generator._DataGeneratorSeq__data_generation([i])         \n",
    "            input_len = len(self.data_generator.input_texts[i])\n",
    "            outputs_true.append(self.data_generator.target_texts[i])\n",
    "            x = sample[0][0][:input_len+self.settings_dict[\"thinking_steps\"]+1]\n",
    "            encoded_texts.append(np.expand_dims(x, axis=0))\n",
    "            \n",
    "        outputs_preds = self.lstm.decode_sample(encoded_texts, self.token_index, max_seq_length)\n",
    "        return outputs_true, outputs_preds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0816 16:05:18.711623 4706751936 deprecation.py:323] From /Users/lewtun/git/deep-math/env/lib/python3.7/site-packages/tensorflow/python/keras/backend.py:3868: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    }
   ],
   "source": [
    "# Build the evaluator from the files found under the project root.\n",
    "lstm_eval = LSTMWithAttentionEvaluator(path=ROOT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_sample=1024"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['train', 'interpolate', 'extrapolate'])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sequence_data['input_texts'].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 14s 14s/step - loss: 2.8174 - exact_match_metric_index: 0.8107\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.81069744"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Teacher-forced exact-match score on the first `test_sample` training examples.\n",
    "lstm_eval.evaluate_model(\n",
    "    input_texts=sequence_data['input_texts']['train'][:test_sample],\n",
    "    output_texts=sequence_data['target_texts']['train'][:test_sample],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 12s 12s/step - loss: 7.1542 - exact_match_metric_index: 0.7504\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.75037444"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Teacher-forced exact-match score on the first `test_sample` interpolation examples.\n",
    "lstm_eval.evaluate_model(\n",
    "    input_texts=sequence_data['input_texts']['interpolate'][:test_sample],\n",
    "    output_texts=sequence_data['target_texts']['interpolate'][:test_sample],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 12s 12s/step - loss: 10.4590 - exact_match_metric_index: 0.6999\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.69989574"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Teacher-forced exact-match score on the first `test_sample` extrapolation examples.\n",
    "lstm_eval.evaluate_model(\n",
    "    input_texts=sequence_data['input_texts']['extrapolate'][:test_sample],\n",
    "    output_texts=sequence_data['target_texts']['extrapolate'][:test_sample],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
